from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
import numpy.typing as npt

from .._serializable import Serializable

if TYPE_CHECKING:
    from collections.abc import Callable


class Masker(Serializable):
    """This is the superclass of all maskers.

    Subclasses are expected to set the ``shape`` and ``clustering`` attributes
    and to override ``__call__``.
    """

    # Subclasses should define these attributes.
    # ``shape`` is either a fixed tuple or a callable that computes the tuple
    # from the same positional arguments the masker is called with; its second
    # entry is used as the length of shorthand (True/False) masks.
    shape: tuple[int | None, int] | Callable[..., tuple[int | None, int]]
    clustering: npt.NDArray[Any] | Callable[..., Any] | None

    def __call__(self, mask: bool | npt.NDArray[Any], *args: Any) -> Any:
        """Maskers are callable objects that accept the same inputs as the model plus a binary mask.

        The base implementation does nothing and returns ``None``; subclasses
        override it to apply ``mask`` to ``args``.
        """

    def _standardize_mask(self, mask: bool | npt.NDArray[Any], *args: Any) -> npt.NDArray[np.bool_]:
        """This allows users to pass True/False as short hand masks.

        Parameters
        ----------
        mask
            Either a mask array (returned unchanged) or a scalar boolean.
            A scalar boolean may be a Python ``bool`` or a NumPy ``np.bool_``.
            It is expanded to an all-True / all-False vector.
        *args
            Forwarded to ``self.shape`` when it is callable, so the mask length
            can depend on the current inputs.

        Returns
        -------
        The mask as given, or a boolean vector of length ``shape[1]`` filled
        with the scalar's value.
        """
        # NumPy boolean scalars (e.g. the result of ``arr.all()``) are not the
        # ``True``/``False`` singletons, so an identity check would let them
        # slip through as 0-d scalars instead of expanding them.
        if not isinstance(mask, (bool, np.bool_)):
            return mask  # type: ignore[return-value]

        shape = self.shape(*args) if callable(self.shape) else self.shape
        return np.full(shape[1], bool(mask), dtype=bool)
